#This repository is the official implementation of the paper
#"Transfer Learning Under High-Dimensional Graph Convolutional
#Regression Model for Node Classification",
#submitted to ICLR 2025.
############################################
################dependencies################
############################################
rm(list=ls())
library(glmnet)
library(foreach)
library(doParallel)
library(R.matlab)
library(Matrix)
library(caret)
library(igraph)
library(ROCR)
library(doParallel)
library(glmnet)
library(randnet)
library(graphon)
library(Matrix)
library(pheatmap)
library(viridisLite)
library(gridExtra)
library(rapportools)

############################################
################functions##################
############################################
#Algorithm 1
# Two-step transfer estimator built on glmnet.
#
# Step 1 fits an L1-penalised model on the pooled data selected by
# `transfer.source.id` (lambda picked by 5-fold cross-validation, refit at
# `lambda.1se`). Step 2 fits a correction on the target data only, using the
# Step 1 linear predictor as offset. The two coefficient matrices are added.
#
# Arguments:
#   target              list with elements `x` (features) and `y` (labels).
#   source              list of lists with the same `x`/`y` layout, or NULL.
#   family              glmnet family. The coefficient handling below binds the
#                       elements of coef() column-wise, i.e. it expects coef()
#                       to return a list (as for "multinomial").
#   transfer.source.id  "all" (pool every source), "auto" (ask
#                       source_detection()), or a numeric vector of source
#                       indices. If it contains 0, or if there is no source,
#                       the target data alone is used in Step 1.
#   cores               if > 1 a doParallel implicit cluster is registered for
#                       the duration of the call.
#   valid.proportion, valid.nfolds, detection.info, target.weights,
#   source.weights, C0, ...   only forwarded to source_detection() ("auto").
#   alpha, standardize, intercept, nfolds, lambda
#                       NOTE: only forwarded to source_detection(). The fits
#                       in this function use fixed settings (alpha = 1,
#                       5 folds, no intercept, no standardisation, default
#                       coef() rule in Step 2) and do not read these values.
#                       Observation weights are likewise not applied here.
#
# Returns the sum of the Step 1 and Step 2 coefficient matrices; row 1 is the
# intercept row as returned by coef(), one column per element of coef().
transgcmlr <- function(target, source = NULL, family = "multinomial",
                       transfer.source.id = "all", alpha = 1, standardize = FALSE, intercept = FALSE,
                       nfolds = 5, cores = 1, valid.proportion = NULL, valid.nfolds = 3,
                       lambda = c(transfer = "lambda.1se", debias = "lambda.min", detection = "lambda.1se"),
                       detection.info = TRUE, target.weights = NULL, source.weights = NULL, C0 = 2, ...) {

  if (is.null(target.weights)) {
    target.weights <- rep(1, length(target$y))
  }

  if (is.null(source.weights)) {
    source.weights <- lapply(seq_along(source), function(i) {
      rep(1, length(source[[i]]$y))
    })
  }

  # An empty list of sources is treated like "no source at all".
  has.source <- length(source) > 0
  id.is.keyword <- is.character(transfer.source.id) && length(transfer.source.id) == 1L

  if (has.source && id.is.keyword && transfer.source.id == "all") {
    # transfer all source data
    transfer.source.id <- seq_along(source)
  } else if (has.source && id.is.keyword && transfer.source.id == "auto") {
    # automatically check which source data set to transfer
    A <- source_detection(target = target, family = family, source = source, alpha = alpha,
                          cores = cores, lambda = lambda["detection"], valid.proportion = valid.proportion,
                          valid.nfolds = valid.nfolds, detection.info = detection.info, standardize = standardize,
                          intercept = intercept, nfolds = nfolds, target.weights = target.weights,
                          source.weights = source.weights, C0 = C0, ...)

    transfer.source.id <- A$transfer.source.id
    # Fail early with a readable message instead of an opaque as.matrix(NULL)
    # error further down when the detection result carries no selection.
    if (length(transfer.source.id) == 0L) {
      stop("source_detection() did not return any 'transfer.source.id'; ",
           "pass explicit source indices or \"all\" instead of \"auto\"")
    }
  } else if (0 %in% transfer.source.id || !has.source) {
    # don't transfer any source
    transfer.source.id <- 0
  }

  # Pool the selected data sets row-wise (id 0 stands for the target).
  pooled.ids <- unique(c(transfer.source.id))
  all.x <- as.matrix(Reduce(rbind, lapply(pooled.ids, function(k) {
    if (k != 0) source[[k]]$x else target$x
  })))
  all.y <- Reduce(c, lapply(pooled.ids, function(k) {
    if (k != 0) source[[k]]$y else target$y
  }))

  if (cores > 1) {
    registerDoParallel(cores)
    # Release the implicit cluster on every exit path, including stop().
    on.exit(stopImplicitCluster(), add = TRUE)
  }

  # step 1
  # cv.glmnet draws random folds and can fail on an unlucky split, so the
  # fit is retried (at most 11 attempts) before giving up.
  n.try <- 0
  repeat {
    cv.fit.trans <- try(cv.glmnet(x = all.x, y = all.y, type.measure = "class", nfolds = 5,
                                  family = family, alpha = 1, lambda.min.ratio = 0.01,
                                  intercept = FALSE, standardize = FALSE), silent = TRUE)
    if (!inherits(cv.fit.trans, "try-error")) {
      break
    }
    n.try <- n.try + 1
    print(paste("tried", n.try, "times during Step 1"))
    if (n.try > 10) {
      stop("errors during Step 1!!!")
    }
  }
  cv.fit.trans1se <- glmnet(x = all.x, y = all.y, type.measure = "class", family = family, alpha = 1,
                            intercept = FALSE, standardize = FALSE, lambda = cv.fit.trans$lambda.1se)

  wa <- as.matrix(do.call(cbind, coef(cv.fit.trans1se)))

  # step 2
  # Linear predictor of Step 1 on the target rows; the intercept row is added
  # column by column (one intercept per column of `wa`).
  target.x <- as.matrix(target$x)
  offset <- sweep(target.x %*% wa[-1, , drop = FALSE], 2, wa[1, ], "+")

  n.try <- 0
  repeat {
    cv.fit.correct <- try(cv.glmnet(x = target.x, y = target$y, type.measure = "class", nfolds = 5,
                                    family = family, alpha = 1, offset = offset, lambda.min.ratio = 0.01,
                                    intercept = FALSE, standardize = FALSE), silent = TRUE)
    if (!inherits(cv.fit.correct, "try-error")) {
      break
    }
    n.try <- n.try + 1
    if (n.try > 10) {
      stop("Errors occur during the Step 2!!!")
    }
  }

  deltaa <- as.matrix(do.call(cbind, coef(cv.fit.correct)))

  beta.hat <- wa + deltaa

  return(beta.hat)
}

# Build `k` cross-validation folds that are stratified by class label.
#
# For every distinct (non-missing) label in `y`, the positions carrying that
# label are split into `k` parts with createFolds(); the j-th parts of all
# labels are then concatenated to form the j-th fold.
#
# Arguments:
#   y  vector of class labels.
#   k  number of folds.
#
# Returns a list of length `k`; element j holds the positions in `y` that
# belong to fold j.
createFolds_multinomial <- function(y, k) {
  # Iterate over the labels actually present. The previous `1:(range(y) + 1)`
  # only used the first element of the range and therefore skipped every
  # class above min(y) + 1.
  class_labels <- sort(unique(y))

  per_class <- lapply(class_labels, function(label) {
    members <- which(y == label)
    # createFolds() returns positions within `members`; map them back to
    # positions within `y`.
    parts <- createFolds(members, k = k)
    lapply(seq_len(k), function(j) {
      if (j <= length(parts)) members[parts[[j]]] else integer(0)
    })
  })

  lapply(seq_len(k), function(j) {
    unlist(lapply(per_class, function(parts) parts[[j]]), use.names = FALSE)
  })
}


#Algorithm 2
# Score every source data set by cross-validated predictive fit on the target.
#
# The target is split into `valid.nfolds` class-stratified folds with
# createFolds_multinomial(). For each source k and each fold, a model is fitted
# with transgcmlr_detect() on the remaining target rows plus source k, and the
# predicted probability of the observed class is summed over the held-out
# rows. Larger scores therefore mean better held-out fit.
#
# Arguments:
#   target        list with `x` (features) and `y` (labels). `y` is used to
#                 index the rows of the class-probability matrix, so labels
#                 are presumably coded 1..K -- confirm with the caller.
#   source        list of lists with the same `x`/`y` layout.
#   valid.nfolds  number of validation folds.
#   cores         if > 1 a doParallel implicit cluster is registered for the
#                 duration of the call.
#   family, alpha, standardize, intercept, nfolds, lambda, detection.info,
#   target.weights, source.weights, C0, ...
#                 NOTE: accepted for interface compatibility but not used;
#                 transgcmlr_detect() is called with its default settings.
#
# Returns list(NL = <numeric vector, one score per source>).
# NOTE(review): callers in this file read `$transfer.source.id` from the
# result, which is not provided here; no selection rule is implemented.
source_detection <- function(target, source = NULL, family = "multinomial", alpha = 1, standardize = FALSE,
                             intercept = FALSE, nfolds = 5, cores = 1, valid.nfolds =3,
                             lambda = "lambda.1se", detection.info = TRUE, target.weights = NULL, source.weights = NULL, C0 = 2, ...) {

  if (cores > 1) {
    registerDoParallel(cores)
    # Release the implicit cluster on every exit path, including errors.
    on.exit(stopImplicitCluster(), add = TRUE)
  }

  folds <- createFolds_multinomial(target$y, valid.nfolds)
  n.target <- length(target$y)

  NL <- numeric(length(source))
  for (k in seq_along(source)) {
    source_k <- list(list(x = source[[k]]$x, y = source[[k]]$y))

    for (i in seq_len(valid.nfolds)) {
      held.out <- folds[[i]]
      # Nothing to validate on; also avoids x[-integer(0), ], which would
      # select zero training rows instead of all of them.
      if (length(held.out) == 0L) {
        next
      }
      kept <- setdiff(seq_len(n.target), held.out)

      target_fold <- list(x = target$x[kept, , drop = FALSE],
                          y = target$y[kept])
      beta_sgc <- transgcmlr_detect(target_fold, source_k)

      testing <- list(x = target$x[held.out, , drop = FALSE],
                      y = target$y[held.out])

      # The first row of the coefficient matrix (intercept row) is not used.
      logits <- as.matrix(testing$x) %*% beta_sgc[-1, , drop = FALSE]

      # Row-wise softmax; subtracting the row maximum leaves the result
      # unchanged mathematically but prevents exp() from overflowing to Inf.
      # apply() returns one column per held-out row.
      probs <- apply(logits, 1, function(z) {
        e <- exp(z - max(z))
        e / sum(e)
      })

      hit.prob <- vapply(seq_len(ncol(probs)), function(mm) {
        probs[testing$y[mm], mm]
      }, numeric(1))
      NL[k] <- NL[k] + sum(hit.prob)
    }
  }

  obj <- list(NL = NL)

  return(obj)
}


# Variant of the two-step transfer estimator used during source detection.
#
# Step 1 fits an L1-penalised model on the pooled data selected by
# `transfer.source.id` (lambda picked by 5-fold cross-validation, refit at
# `lambda.1se`). Step 2 fits a correction on the target data only, using the
# Step 1 linear predictor as offset: with a fixed penalty `cus_value` when it
# is given, otherwise with a cross-validated penalty. The two coefficient
# matrices are added.
#
# Arguments:
#   target              list with elements `x` (features) and `y` (labels).
#   source              list of lists with the same `x`/`y` layout, or NULL.
#   family              glmnet family. The coefficient handling below binds the
#                       elements of coef() column-wise, i.e. it expects coef()
#                       to return a list (as for "multinomial").
#   transfer.source.id  "all" (pool every source), "auto" (ask
#                       source_detection()), or a numeric vector of source
#                       indices. If it contains 0, or if there is no source,
#                       the target data alone is used in Step 1.
#   cores               if > 1 a doParallel implicit cluster is registered for
#                       the duration of the call.
#   cus_value           fixed lambda for Step 2; NULL switches Step 2 to
#                       cross-validation.
#   target.weights      observation weights for the fixed-lambda Step 2 fit
#                       (defaults to all ones); not used by the other fits.
#   valid.proportion, valid.nfolds, detection.info, source.weights, C0, ...
#                       only forwarded to source_detection() ("auto").
#   alpha, standardize, intercept, nfolds, lambda
#                       NOTE: only forwarded to source_detection(). The fits
#                       in this function use fixed settings (alpha = 1,
#                       5 folds, no intercept, no standardisation).
#
# Returns the sum of the Step 1 and Step 2 coefficient matrices; row 1 is the
# intercept row as returned by coef(), one column per element of coef().
transgcmlr_detect <- function(target, source = NULL, family = "multinomial",
                              transfer.source.id = "all", alpha = 1, standardize = FALSE, intercept = FALSE,
                              nfolds = 5, cores = 1,cus_value=0.5, valid.proportion = NULL, valid.nfolds = 3,
                              lambda = c(transfer = "lambda.1se", debias = "lambda.min", detection = "lambda.1se"),
                              detection.info = TRUE, target.weights = NULL, source.weights = NULL, C0 = 2, ...) {

  if (is.null(target.weights)) {
    target.weights <- rep(1, length(target$y))
  }

  if (is.null(source.weights)) {
    source.weights <- lapply(seq_along(source), function(i) {
      rep(1, length(source[[i]]$y))
    })
  }

  # An empty list of sources is treated like "no source at all".
  has.source <- length(source) > 0
  id.is.keyword <- is.character(transfer.source.id) && length(transfer.source.id) == 1L

  if (has.source && id.is.keyword && transfer.source.id == "all") {
    # transfer all source data
    transfer.source.id <- seq_along(source)
  } else if (has.source && id.is.keyword && transfer.source.id == "auto") {
    # automatically check which source data set to transfer
    A <- source_detection(target = target, family = family, source = source, alpha = alpha,
                          cores = cores, lambda = lambda["detection"], valid.proportion = valid.proportion,
                          valid.nfolds = valid.nfolds, detection.info = detection.info, standardize = standardize,
                          intercept = intercept, nfolds = nfolds, target.weights = target.weights,
                          source.weights = source.weights, C0 = C0, ...)

    transfer.source.id <- A$transfer.source.id
    # Fail early with a readable message instead of an opaque as.matrix(NULL)
    # error further down when the detection result carries no selection.
    if (length(transfer.source.id) == 0L) {
      stop("source_detection() did not return any 'transfer.source.id'; ",
           "pass explicit source indices or \"all\" instead of \"auto\"")
    }
  } else if (0 %in% transfer.source.id || !has.source) {
    # don't transfer any source
    transfer.source.id <- 0
  }

  # step 1
  # --------------------------------------
  # Pool the selected data sets row-wise (id 0 stands for the target).
  pooled.ids <- unique(c(transfer.source.id))
  all.x <- as.matrix(Reduce(rbind, lapply(pooled.ids, function(k) {
    if (k != 0) source[[k]]$x else target$x
  })))
  all.y <- Reduce(c, lapply(pooled.ids, function(k) {
    if (k != 0) source[[k]]$y else target$y
  }))

  if (cores > 1) {
    registerDoParallel(cores)
    # Release the implicit cluster on every exit path, including stop().
    on.exit(stopImplicitCluster(), add = TRUE)
  }

  # cv.glmnet draws random folds and can fail on an unlucky split, so the
  # fit is retried (at most 11 attempts) before giving up.
  n.try <- 0
  repeat {
    cv.fit.trans <- try(cv.glmnet(x = all.x, y = all.y, type.measure = "class", nfolds = 5,
                                  family = family, alpha = 1, lambda.min.ratio = 0.01,
                                  intercept = FALSE, standardize = FALSE), silent = TRUE)
    if (!inherits(cv.fit.trans, "try-error")) {
      break
    }
    n.try <- n.try + 1
    print(paste("tried", n.try, "times during step 1!"))
    if (n.try > 10) {
      stop("errors during step 1!!!")
    }
  }
  cv.fit.trans1se <- glmnet(x = all.x, y = all.y, type.measure = "class", family = family, alpha = 1,
                            intercept = FALSE, standardize = FALSE, lambda = cv.fit.trans$lambda.1se)

  wa <- as.matrix(do.call(cbind, coef(cv.fit.trans1se)))

  # step 2
  # Linear predictor of Step 1 on the target rows; the intercept row is added
  # column by column (one intercept per column of `wa`).
  target.x <- as.matrix(target$x)
  offset <- sweep(target.x %*% wa[-1, , drop = FALSE], 2, wa[1, ], "+")

  if (is.null(cus_value)) {
    # No fixed penalty supplied: choose it by cross-validation, with retries.
    n.try <- 0
    repeat {
      cv.fit.correct <- try(cv.glmnet(x = target.x, y = target$y, type.measure = "class", nfolds = 5,
                                      family = family, alpha = 1, offset = offset, lambda.min.ratio = 0.01,
                                      intercept = FALSE, standardize = FALSE), silent = TRUE)
      if (!inherits(cv.fit.correct, "try-error")) {
        break
      }
      n.try <- n.try + 1
      if (n.try > 10) {
        stop("Errors occur during the Step 2!!!")
      }
    }
  } else {
    # Single fit at the caller-supplied penalty.
    cv.fit.correct <- glmnet(x = target.x, y = target$y, weights = target.weights, offset = offset,
                             family = family, alpha = 1, intercept = FALSE, standardize = FALSE,
                             lambda = cus_value, maxit = 10^6)
  }

  deltaa <- as.matrix(do.call(cbind, coef(cv.fit.correct)))

  beta.hat <- wa + deltaa

  return(beta.hat)
}


# Normalize A
# Symmetric degree normalisation S = D^(-1/2) %*% A %*% D^(-1/2), where D is
# the diagonal matrix of the row sums of A.
#
# Arguments:
#   A  square adjacency matrix.
#
# Returns the normalised matrix S. Rows/columns whose row sum has no finite
# inverse square root (e.g. zero-degree nodes) are scaled by 0, so they come
# out as zeros instead of NaN/Inf.
# Stops with an error if A is not square.
normalize <- function(A) {
  if (nrow(A) != ncol(A)) {
    stop("Adjacency matrix must be square.")
  }

  inv.sqrt.deg <- rowSums(A)^(-0.5)
  # 0^(-0.5) is Inf and Inf * 0 is NaN; neutralise such entries.
  inv.sqrt.deg[!is.finite(inv.sqrt.deg)] <- 0

  # `nrow` is given explicitly because diag() on a length-one vector would
  # otherwise build an identity matrix of that size rather than a 1 x 1 matrix.
  D <- diag(inv.sqrt.deg, nrow = length(inv.sqrt.deg))
  S <- D %*% A %*% D
  return(S)
}
